import numpy as np
import pandas as pd
from pathlib import Path


def preprocess(df_all=None, x_column_names=None, y_column_names=None,
               g_column_name=None):
    '''Preprocess the dataset of Obermeyer et al. (2019).

    Drops feature columns that contain a single distinct value and returns
    a copy of ``df_all`` restricted to the remaining features, the label
    columns and the group column, with the columns sorted by name.

    Parameters
    ----------
    df_all : pd.DataFrame, optional
        Input data. When None, the dataset is loaded with
        ``load_data_obermeyer()`` and the column arguments are ignored and
        replaced by defaults: labels ``['cost_t', 'gagne_sum_t']``, group
        ``'dem_race_black'``, features = all loaded features except the
        group column.
    x_column_names : list of str, optional
        Feature columns. Required when ``df_all`` is given.
    y_column_names : list of str, optional
        Label columns. Required when ``df_all`` is given.
    g_column_name : str, optional
        Group column. When None (and ``df_all`` is given) no group column
        is included in the output.

    Returns
    -------
    pd.DataFrame
        Copy of the selected columns, in sorted column order.

    Raises
    ------
    ValueError
        If ``df_all`` is given without ``x_column_names`` or
        ``y_column_names``.
    '''
    # load dataset
    if df_all is None:
        df_all, x_column_names = load_data_obermeyer()
        y_column_names = ['cost_t', 'gagne_sum_t']
        g_column_name = 'dem_race_black'
        x_column_names.remove(g_column_name)
    elif x_column_names is None or y_column_names is None:
        raise ValueError('x_column_names and y_column_names must be '
                         'provided together with df_all')

    # Constant features carry no information. len(unique()) counts NaN as
    # a value, so an all-NaN column is treated as constant too.
    x_remove = [x for x in x_column_names
                if len(df_all[x].unique()) == 1]
    print('removed columns:', x_remove)

    removed = set(x_remove)
    column_names = set(y_column_names)
    column_names.update(x for x in x_column_names if x not in removed)
    if g_column_name is not None:
        column_names.add(g_column_name)

    # For reproducibility, we sort the list (set order is arbitrary).
    return df_all[sorted(column_names)].copy()


def load_data_obermeyer():
    '''Load the dataset of Obermeyer et al. (2019).

    Returns
    -------
    df_all : pd.DataFrame
        The frame produced by ``get_Y_x_df`` on ``load_data_df()``.
    x_column_names : list of str
        Feature column names. Note that this list includes the race
        variable `dem_race_black`.
    '''
    raw = load_data_df()
    selected, features, _labels = get_Y_x_df(raw, verbose=False)
    return selected, features


def load_data_df():
    """Read ``data/data_new.csv`` from the directory above this module's folder.

    ``reset_index()`` (without ``drop``) turns the row index into a regular
    ``'index'`` column, which ``get_Y_x_df`` selects as its cohort column.
    The original comment gave the reason as "because we removed patient";
    presumably some rows were filtered out upstream — confirm against the
    data-preparation step.

    Returns
    -------
    pd.DataFrame
        DataFrame to use for analysis.
    """
    csv_path = Path(__file__).resolve().parents[1] / 'data' / 'data_new.csv'
    return pd.read_csv(csv_path).reset_index()


def get_Y_x_df(df, verbose):
    """Select the feature (x), label (Y) and table-metric columns of ``df``.

    Side effect: ``df`` is modified in place. The columns ``'log_cost_t'``,
    ``'log_cost_avoidable_t'`` and ``'dem_race_black'`` (1 where
    ``df['race'] == 'black'``, else 0) are written into it.

    Parameters
    ----------
    df : pd.DataFrame
        Data dataframe. Must already contain an ``'index'`` column, as
        produced by ``load_data_df``.
    verbose : bool
        Print statistics of features.

    Returns
    -------
    all_Y_x_df : pd.DataFrame
        Copy of the columns ``'index'`` + features + labels + table metrics.
    x_column_names : list
        List of all x column names (features), including
        ``'dem_race_black'``.
        NB: 'risk_score_t' is not included
    Y_predictors : list
        All labels (Y) to predict:
        ``['log_cost_t', 'gagne_sum_t', 'log_cost_avoidable_t']``.

    Notes
    -----
    Of the time-t variables, this function keeps 'risk_score_t',
    'program_enrolled_t', 'cost_t', 'cost_avoidable_t' and 'gagne_sum_t'.
    It adds 'log_cost_t' and 'log_cost_avoidable_t'. The 'index' column
    comes from ``load_data_df``. The original notes state the raw data has
    10 time-t variables and the output 158 (= 160-5+3) columns. Those counts
    depend on the CSV contents and are not checked here.
    """
    # The feature list is taken from df *before* the derived columns below
    # are written into it.
    features = get_all_features(df, verbose)

    for log_name, raw_name in (('log_cost_t', 'cost_t'),
                               ('log_cost_avoidable_t', 'cost_avoidable_t')):
        df[log_name] = convert_to_log(df, raw_name)

    labels = ['log_cost_t', 'gagne_sum_t', 'log_cost_avoidable_t']

    # binary group indicator derived from the string 'race' column
    df['dem_race_black'] = np.where(df['race'] == 'black', 1, 0)

    # extra metrics used for table 2 and table 3
    table_metrics = ['risk_score_t', 'program_enrolled_t',
                     'cost_t', 'cost_avoidable_t']

    selected = ['index'] + features + labels + table_metrics
    return df[selected].copy(), features, labels


def get_all_features(df, verbose=False):
    """Collect every feature column of ``df``, grouped by category.

    The per-category lists are concatenated in this order: demographic,
    comorbidity, cost, lab, med.

    Parameters
    ----------
    df : pd.DataFrame
        Data dataframe.
    verbose : bool
        Print the number of features per category and the total.

    Returns
    -------
    x_column_names: list of str
        List of all features.

    Remark
    ------
    Unlike the original code, `dem_race_black` is included in
    `x_column_names` (it is contributed by ``get_dem_features``).
    """
    breakdown = (
        ('demographic', get_dem_features(df)),
        ('comorbidity', get_comorbidity_features(df)),
        ('cost', get_cost_features(df)),
        ('lab', get_lab_features(df)),
        ('med', get_med_features(df)),
    )
    x_column_names = [name for _, group in breakdown for name in group]

    if verbose:
        print('Features breakdown:')
        for label, group in breakdown:
            print('   {}: {}'.format(label, len(group)))
        print(' {}: {}'.format('TOTAL', len(x_column_names)))

    return x_column_names


def get_dem_features(df):
    """Get demographic features.

    Parameters
    ----------
    df : pd.DataFrame
        Data dataframe.

    Returns
    -------
    list of str
        Column names of ``df`` starting with ``'dem_'`` (in column order).
        ``'dem_race_black'`` is appended at the end if it is not already
        among them.

    Remark
    ------
    The group variable (`dem_race_black`) was not included in the output
    in the original code; it is added here. ``get_Y_x_df`` writes a
    ``dem_race_black`` column into its input frame, so the column can
    already exist (e.g. on a second call on the same frame). In that case
    it is listed only once instead of twice. Non-string column labels are
    skipped.
    """
    prefix = 'dem_'
    dem_features = [col for col in df.columns
                    if isinstance(col, str) and col.startswith(prefix)]
    if 'dem_race_black' not in dem_features:
        dem_features.append('dem_race_black')  # added: group variable
    return dem_features


def get_comorbidity_features(df):
    """Get comorbidity features.

    Selects ``'gagne_sum_tm1'`` and every column ending in
    ``'_elixhauser_tm1'`` or ``'_romano_tm1'``, in column order. Non-string
    column labels are skipped.

    Parameters
    ----------
    df : pd.DataFrame
        Data dataframe.

    Returns
    -------
    list
        List of comorbidity features.

    """
    comorbidity_sum = 'gagne_sum_tm1'
    suffixes = ('_elixhauser_tm1', '_romano_tm1')
    return [col for col in df.columns
            if isinstance(col, str)
            and (col == comorbidity_sum or col.endswith(suffixes))]


def get_cost_features(df):
    """Get cost features.

    Selects every column starting with ``'cost_'`` except the outcomes
    ``'cost_t'`` and ``'cost_avoidable_t'``, in column order. Non-string
    column labels are skipped.

    Parameters
    ----------
    df : pd.DataFrame
        Data dataframe.

    Returns
    -------
    list
        List of cost features.

    """
    # 'cost_t', 'cost_avoidable_t' are outcomes, not features
    outcomes = {'cost_t', 'cost_avoidable_t'}
    return [col for col in df.columns
            if isinstance(col, str) and col.startswith('cost_')
            and col not in outcomes]


def get_lab_features(df):
    """Get lab features.

    Selects every column ending in ``'_tests_tm1'`` (test counts),
    ``'-low_tm1'``, ``'-high_tm1'`` or ``'-normal_tm1'``, in column order.
    Non-string column labels are skipped.

    Parameters
    ----------
    df : pd.DataFrame
        Data dataframe.

    Returns
    -------
    list
        List of lab features.

    """
    suffixes = ('_tests_tm1', '-low_tm1', '-high_tm1', '-normal_tm1')
    return [col for col in df.columns
            if isinstance(col, str) and col.endswith(suffixes)]


def get_med_features(df):
    """Get med features.

    Selects every column starting with ``'lasix_'``, in column order.
    Non-string column labels are skipped.

    Parameters
    ----------
    df : pd.DataFrame
        Data dataframe.

    Returns
    -------
    list
        List of med features.

    """
    prefix = 'lasix_'
    return [col for col in df.columns
            if isinstance(col, str) and col.startswith(prefix)]


def convert_to_log(df, col_name, epsilon=1):
    """Convert column to log10 space.

    Computes ``log10(x + epsilon)``. The offset keeps zero values finite,
    since ``log10(0)`` is ``-inf``. With the default ``epsilon=1`` a value
    of 0 maps to 0.

    Parameters
    ----------
    df : pd.DataFrame
        Data dataframe.
    col_name : str
        Name of column in df to convert to log.
    epsilon : float, optional
        Offset added before taking the logarithm. The default of 1 matches
        the previously hard-coded value. Entries with ``x + epsilon == 0``
        still give ``-inf``, and entries with ``x + epsilon < 0`` give NaN,
        as with ``np.log10``.

    Returns
    -------
    np.ndarray
        Values of column in log space

    """
    return np.log10(df[col_name].values + epsilon)